from data_provider.data_factory import data_provider
from exp.exp_basic import Exp_Basic
from utils.tools import EarlyStopping, adjust_learning_rate
from utils.metrics import metric
import torch
import torch.nn as nn
from torch import optim
import os
import re  
import time
import warnings
import numpy as np
import json
from sklearn.metrics import r2_score
import shap
import matplotlib.pyplot as plt
warnings.filterwarnings('ignore')
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from utils.tools import EarlyStopping, adjust_learning_rate
from data_provider.data_loader import Dataset_Custom
from torch.utils.data import DataLoader
import os
import numpy as np
import torch
import pandas as pd  # 导入pandas库
class Exp_Long_Term_Forecast(Exp_Basic):
    def __init__(self, args):
        """Set up the experiment from ``args``.

        Initialisation is delegated to ``Exp_Basic``; ``args`` is then stored
        on the instance explicitly so the methods of this class can read it.
        """
        super().__init__(args)
        # Explicit assignment so self.args is guaranteed to be the object passed in.
        self.args = args

    def _build_model(self):
        model = self.model_dict[self.args.model].Model(self.args).float()

        if self.args.use_multi_gpu and self.args.use_gpu:
            model = nn.DataParallel(model, device_ids=self.args.device_ids)
        return model

    def _get_data(self, flag):
        """Return the ``(dataset, loader)`` pair that ``data_provider`` builds for ``flag``.

        Args:
            flag: Split selector forwarded to ``data_provider``; this class
                calls it with ``'train'``, ``'val'`` and ``'test'``.
        """
        provided = data_provider(self.args, flag)
        data_set, data_loader = provided
        return data_set, data_loader

    def _select_optimizer(self):
        model_optim = optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
        return model_optim

    def _select_criterion(self):
        if self.args.loss == 'MSE' or self.args.loss == 'mse':
            criterion = nn.MSELoss()
        elif self.args.loss == 'MAE' or self.args.loss == 'mae':
            criterion = nn.L1Loss()
        return criterion

    def vali(self, vali_data, vali_loader, criterion):
        """Run the model over ``vali_loader`` and return a single loss value.

        Predictions and targets from every batch are gathered as NumPy
        arrays, concatenated along the first axis and passed to
        ``utils.metrics.metric``; one of its two results is returned.

        Args:
            vali_data: Dataset behind ``vali_loader``. Not used here; kept
                for interface compatibility.
            vali_loader: Iterable yielding ``(batch_x, batch_y)`` tensor pairs.
            criterion: Selects which metric to return. ``train`` passes the
                ``args.loss`` string; ``'mae'`` (any case) selects MAE. An
                ``nn.L1Loss`` instance is also treated as MAE. Anything else
                selects MSE.

        Returns:
            The MAE or MSE over the whole loader, as computed by ``metric``.

        Raises:
            ValueError: If ``vali_loader`` yields no batches (previously an
                ``IndexError`` from indexing an empty list).
        """
        preds = []
        trues = []
        self.model.eval()
        try:
            with torch.no_grad():
                for batch_x, batch_y in vali_loader:
                    batch_x = batch_x.float().to(self.device, non_blocking=True)
                    # Only the last pred_len steps are scored; targets are
                    # never moved to the device because they go straight to NumPy.
                    batch_y = batch_y[:, -self.args.pred_len:, :].float()
                    if self.args.use_amp:
                        with torch.cuda.amp.autocast():
                            outputs = self.model(batch_x)
                    else:
                        outputs = self.model(batch_x)
                    preds.append(outputs.detach().cpu().numpy())
                    trues.append(batch_y.detach().numpy())
        finally:
            # Restore training mode even if inference raised, so a caller that
            # handles the error does not keep training in eval mode.
            self.model.train()

        if not preds:
            raise ValueError('vali_loader produced no batches; cannot compute a validation loss')

        preds = np.concatenate(preds, axis=0)
        trues = np.concatenate(trues, axis=0)
        mse, mae = metric(preds, trues)

        if isinstance(criterion, str):
            use_mae = criterion.lower() == 'mae'
        else:
            use_mae = isinstance(criterion, nn.L1Loss)
        vali_loss = mae if use_mae else mse

        torch.cuda.empty_cache()
        return vali_loss

    def train(self, setting):
        """Train the model, early-stopping on the validation loss.

        For each epoch the method runs one pass over the training loader,
        evaluates on the validation and test loaders through ``vali``, prints
        a summary line, hands the validation loss to ``EarlyStopping``
        (together with the model and the checkpoint directory) and finally
        calls ``adjust_learning_rate``.

        Args:
            setting: Run identifier; the checkpoint directory is
                ``os.path.join(args.checkpoints, setting)`` and is created if
                it does not exist.

        Returns:
            None.
        """
        train_data, train_loader = self._get_data(flag='train')
        vali_data, vali_loader = self._get_data(flag='val')
        test_data, test_loader = self._get_data(flag='test')

        path = os.path.join(self.args.checkpoints, setting)
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(path, exist_ok=True)

        train_steps = len(train_loader)
        early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)

        model_optim = self._select_optimizer()
        criterion = self._select_criterion()

        use_amp = self.args.use_amp
        scaler = torch.cuda.amp.GradScaler() if use_amp else None

        for epoch in range(self.args.train_epochs):
            train_loss = []

            self.model.train()
            epoch_time = time.time()
            for batch_x, batch_y in train_loader:
                model_optim.zero_grad(set_to_none=True)
                batch_x = batch_x.float().to(self.device, non_blocking=True)
                # Only the last pred_len steps of the target are trained on.
                batch_y = batch_y[:, -self.args.pred_len:, :].float().to(self.device, non_blocking=True)

                if use_amp:
                    # Forward pass and loss under autocast; the backward pass
                    # goes through the gradient scaler.
                    with torch.cuda.amp.autocast():
                        outputs = self.model(batch_x)
                        loss = criterion(outputs, batch_y)
                    scaler.scale(loss).backward()
                    scaler.step(model_optim)
                    scaler.update()
                else:
                    outputs = self.model(batch_x)
                    loss = criterion(outputs, batch_y)
                    loss.backward()
                    model_optim.step()
                train_loss.append(loss.item())

                # NOTE(review): emptying the CUDA cache on every step is kept from
                # the original; it is likely costly — confirm it is still needed.
                torch.cuda.empty_cache()

            print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
            train_loss = np.average(train_loss)
            # vali() receives the loss *name* and returns the matching metric.
            vali_loss = self.vali(vali_data, vali_loader, self.args.loss)
            test_loss = self.vali(test_data, test_loader, self.args.loss)
            print("Epoch: {}, Steps: {} | Train Loss: {:.3f}  vali_loss: {:.3f}   test_loss: {:.3f} ".format(epoch + 1, train_steps, train_loss,  vali_loss, test_loss))
            early_stopping(vali_loss, self.model, path)
            if early_stopping.early_stop:
                print("Early stopping")
                break

            adjust_learning_rate(model_optim, epoch + 1, self.args)
        torch.cuda.empty_cache()
    def shap_analysis(self):
        """Compute SHAP attributions on a sample of the training data and save plots.

        Steps, as written:
          1. Load the train split and take its ``data_x`` array, which must be 2-D.
          2. Inverse-transform it and draw up to 100 random rows as both the
             explained sample and the background set.
          3. Build a ``shap.DeepExplainer`` around a wrapper that re-applies the
             dataset scaler before calling the model, and compute SHAP values.
          4. Print per-feature mean and mean-absolute SHAP values, with the last
             feature excluded from the ranking.
          5. Save a violin summary plot, a top-20 bar chart and dependence plots
             for the top 10 features as PNG files in the current working directory.

        Any exception raised along the way is caught and printed; nothing is
        re-raised and nothing is returned.
        """
        print("Starting SHAP analysis...")

        try:
            # Get the training data
            train_data, train_loader = self._get_data(flag='train')
            # NOTE(review): the test split is loaded but not used below.
            test_data, test_loader = self._get_data(flag='test')
            print(f"Train data size: {len(train_data)}")

            # Build a float32 tensor on the experiment device from the dataset's data_x array
            train_data_tensor = torch.tensor(train_data.data_x, dtype=torch.float32).to(self.device)
            print(f"Train data tensor shape: {train_data_tensor.shape}")

            # Ensure the data has the correct shape (samples, features)
            if len(train_data_tensor.shape) != 2:
                raise ValueError(f"Unexpected shape of train_data_tensor: {train_data_tensor.shape}")

            # Get the dimensions
            n_samples, n_features = train_data_tensor.shape
            print(f"Data dimensions: samples={n_samples}, features={n_features}")

            # Convert the data to numpy and inverse transform
            train_data_2d = train_data_tensor.cpu().numpy()
            train_data_2d_inverse = train_data.inverse_transform(train_data_2d)
            print(f"Inverse transformed train_data_2d shape: {train_data_2d_inverse.shape}")

            # Randomly select up to 100 rows, without replacement
            sample_size = min(100, n_samples)
            sample_indices = np.random.choice(n_samples, sample_size, replace=False)
            sample_data = train_data_2d_inverse[sample_indices]
            print(f"Sample data shape: {sample_data.shape}")

            # Use a subset of sample data as background
            # (with both caps at 100 the background is the whole sample)
            background_size = min(100, sample_size)
            background = sample_data[:background_size]
            print(f"Background shape: {background.shape}")

            # Ensure the model is in evaluation mode
            self.model.eval()

            # Create a wrapper function to normalize the input before passing it to the model
            # NOTE(review): the wrapper feeds 2-D (rows, features) input to the model, whereas
            # the other methods feed loader batches — confirm the model accepts this shape.
            def model_wrapper(x):
                x_normalized = torch.tensor(train_data.scaler.transform(x), dtype=torch.float32).to(self.device)
                return self.model(x_normalized)

            # Create a Deep SHAP explainer
            # NOTE(review): DeepExplainer is handed a plain Python function and a NumPy
            # background here; verify against the SHAP docs that this is supported,
            # since it is documented for framework model objects and tensors.
            print("Creating DeepExplainer...")
            explainer = shap.DeepExplainer(model_wrapper, background)
            print("DeepExplainer created successfully.")

            # Calculate SHAP values for sample data
            print("Calculating SHAP values...")
            shap_values = explainer.shap_values(sample_data)
            print("SHAP values calculated successfully.")

            print(f"SHAP values shape: {np.array(shap_values).shape}")

            # Handle multi-output case
            if isinstance(shap_values, list):
                shap_values = np.array(shap_values)

            # A 3-D result is assumed to be laid out as (output_dim, samples, features)
            if len(shap_values.shape) == 3:  # (output_dim, samples, features)
                shap_values = shap_values[0]  # Use the first output dimension

            print(f"SHAP values shape after processing: {shap_values.shape}")

            # Calculate mean absolute SHAP values (and the signed mean) per feature
            shap_values_mean_abs = np.mean(np.abs(shap_values), axis=0)
            shap_values_mean = np.mean(shap_values, axis=0)

            # Get feature names from the raw CSV header: the first two columns and
            # the last column are skipped
            df_raw = pd.read_csv(os.path.join(self.args.root_path, self.args.data_path))
            cols_data = df_raw.columns[2:-1]  # Exclude the last column (TB incidence)

            df_data = df_raw[cols_data]

            feature_names = df_data.columns.tolist()
            print(f"All feature names: {feature_names}")

            # Sort features by importance, excluding the last feature (TB incidence)
            # NOTE(review): assumes feature_names lines up index-for-index with the
            # SHAP columns once the last one is dropped — confirm against the dataset.
            feature_importance_order = np.argsort(shap_values_mean_abs[:-1])[::-1]
            top_10_features = feature_importance_order[:10]
            print(f"Top 10 features: {top_10_features}")

            # Print feature importance ranking
            print("Feature importance ranking and SHAP values:")
            print("Index: Feature Name - Mean Absolute SHAP Value")
            for i, idx in enumerate(feature_importance_order):
                shap_value_abs = shap_values_mean_abs[idx]
                print(f"{i + 1}: {feature_names[idx]} - shap_value_abs: {shap_value_abs:.6f} - shap_value: {shap_values_mean[idx]:.6f}")

            # Print SHAP values for each feature
            print("\nSHAP values for each feature:")
            for feature_idx in range(n_features - 1):  # Exclude the last feature
                shap_value = shap_values_mean[feature_idx]
                print(f"Feature: {feature_names[feature_idx]}, SHAP Value: {shap_value:.6f}")

            # Plot SHAP summary plot (violin plot), again without the last feature
            print("Plotting SHAP summary plot (violin)...")
            plt.figure(figsize=(12, 8))
            shap.summary_plot(shap_values[:, :-1], sample_data[:, :-1], 
                            feature_names=feature_names,
                            plot_type="violin", show=False)

            plt.xlabel("SHAP value (impact on model output)", family='Times New Roman', fontsize=14)
            # NOTE(review): plt.rc is called after the plot has been drawn, so it
            # presumably only influences later figures — confirm intent.
            plt.rc('font', family='Times New Roman', size=15)
            plt.tight_layout()
            plt.savefig('shap_summary_plot-1.png')
            plt.close()
            print("SHAP summary plot saved as 'shap_summary_plot-1.png'")

            # Plot feature importance
            print("Plotting feature importance...")
            plt.figure(figsize=(12, 8))
            top_20_features = feature_importance_order[:20]  # Top 20 features
            top_20_feature_names = [feature_names[idx] for idx in top_20_features]
            plt.barh(top_20_feature_names, shap_values_mean_abs[top_20_features])
            plt.xlabel("mean(|SHAP value|) (average impact on model output magnitude)", fontfamily='Times New Roman', fontsize=14)
            plt.ylabel("Features", fontfamily='Times New Roman', fontsize=14)
            plt.title("Feature Importance", fontfamily='Times New Roman', fontsize=16)
            plt.yticks(fontfamily='Times New Roman', fontsize=12)
            plt.xticks(fontfamily='Times New Roman', fontsize=12)
            # Invert so the most important feature is drawn at the top
            plt.gca().invert_yaxis()
            plt.tight_layout()
            plt.savefig('feature_importance_plot.png')
            plt.close()
            print("Feature importance plot saved as 'feature_importance_plot.png'")

            # Select top 10 features
            top_10_feature_names = [feature_names[i] for i in top_10_features]
            print(f"Top 10 feature names: {top_10_feature_names}")

            # Additional diagnostic functions
            # NOTE(review): neither helper is defined in the visible part of this class;
            # presumably they live further down or in Exp_Basic — verify they exist.
            self._check_shap_values_distribution(shap_values)
            self._check_model_predictions(sample_data)

            # Plot SHAP dependence plots for top features
            # NOTE(review): the full shap_values/sample_data (last column included) are
            # passed with feature_names that omit it — confirm the lengths agree.
            print("Plotting SHAP dependence plots for top features...")
            for feature_idx in top_10_features:
                plt.figure(figsize=(10, 6))
                sanitized_feature_name = re.sub(r'\W+', '_', feature_names[feature_idx])  # Sanitize feature name
                shap.dependence_plot(feature_idx, shap_values, sample_data, feature_names=feature_names, show=False)
                plt.tight_layout()
                plt.savefig(f'shap_dependence_plot_{sanitized_feature_name}.png')
                plt.close()
                print(f"SHAP dependence plot for {feature_names[feature_idx]} saved as 'shap_dependence_plot_{sanitized_feature_name}.png'")

        # Best-effort: any failure above is reported on stdout and swallowed.
        except Exception as e:
            print(f"An error occurred during SHAP analysis: {e}")

    def test(self, setting, test=1):
        """Evaluate on the test split and write predictions and metrics to disk.

        Output goes to
        ``./test_dict/<data_path without extension>/<seq_len>_to_<pred_len>/<model>/<loss>/bz_<batch_size>/lr_<learning_rate>/``:

        * ``test_results.csv`` - flattened, inverse-transformed truths and predictions.
        * ``matching_test_results.csv`` - the rows of that table selected by the
          matching step described in the body.
        * ``records.json`` - MSE, MAE and R^2 formatted to three decimals.

        The metrics are computed on the arrays as produced by the model and
        loader, i.e. before ``inverse_transform`` is applied.

        Args:
            setting: Run identifier; the checkpoint is read from
                ``os.path.join(args.checkpoints, setting, 'checkpoint.pth')``.
            test: When truthy, load the checkpoint into the model before
                evaluating; otherwise the model's current weights are used.

        Returns:
            None.

        Raises:
            ValueError: If the test loader yields no batches.
        """
        test_data, test_loader = self._get_data(flag='test')
        path = os.path.join(self.args.checkpoints, setting)
        if test:
            print('loading model')
            # map_location lets a checkpoint saved on another device
            # (e.g. GPU) be loaded on whatever device this run uses.
            state_dict = torch.load(os.path.join(path, 'checkpoint.pth'), map_location=self.device)
            self.model.load_state_dict(state_dict)

        # splitext drops the extension whatever its length (the original sliced off 4 characters).
        dataset_name = os.path.splitext(self.args.data_path)[0]
        head = f'./test_dict/{dataset_name}/{self.args.seq_len}_to_{self.args.pred_len}/'
        tail = f'{self.args.model}/{self.args.loss}/bz_{self.args.batch_size}/lr_{self.args.learning_rate}/'
        dict_path = head + tail
        os.makedirs(dict_path, exist_ok=True)

        self.model.eval()
        preds = []
        trues = []
        with torch.no_grad():
            for batch_x, batch_y in test_loader:
                batch_x = batch_x.float().to(self.device, non_blocking=True)
                # Only the last pred_len steps are evaluated; targets stay on the CPU.
                batch_y = batch_y[:, -self.args.pred_len:, :].float()
                if self.args.use_amp:
                    with torch.cuda.amp.autocast():
                        outputs = self.model(batch_x)
                else:
                    outputs = self.model(batch_x)
                preds.append(outputs.detach().cpu().numpy())
                trues.append(batch_y.detach().numpy())

        if not preds:
            raise ValueError('test loader produced no batches; nothing to evaluate')

        preds = np.concatenate(preds, axis=0)
        trues = np.concatenate(trues, axis=0)

        # Collapse every leading axis so each row is one time step across the feature axis.
        preds = preds.reshape(-1, preds.shape[-1])
        trues = trues.reshape(-1, trues.shape[-1])

        print('test shape:', preds.shape, trues.shape)

        # Inverse transform the predictions and truths
        preds_inverse = test_data.inverse_transform(preds)
        trues_inverse = test_data.inverse_transform(trues)

        # Save to CSV
        results = pd.DataFrame({
            'Truth': trues_inverse.flatten(),
            'Prediction': preds_inverse.flatten()
        })
        results.to_csv(os.path.join(dict_path, 'test_results.csv'), index=False)

        # Load original CSV to get target values
        original_data = pd.read_csv(os.path.join(self.args.root_path, self.args.data_path))
        target_values = original_data[self.args.target].values

        # Keep the positions (in the original CSV) whose target value occurs
        # anywhere in the inverse-transformed truths, then select the rows of
        # `results` at those same positions.
        # NOTE(review): positions from the raw CSV are used to index the flattened
        # results table, which can raise IndexError when they exceed its length and
        # relies on exact float equality — confirm the intended matching rule.
        matching_indices = [i for i, val in enumerate(target_values) if val in trues_inverse]
        matching_results = results.iloc[matching_indices]

        # Save matching results to new CSV
        matching_results.to_csv(os.path.join(dict_path, 'matching_test_results.csv'), index=False)

        mse, mae = mean_squared_error(trues, preds), mean_absolute_error(trues, preds)
        r2 = r2_score(trues, preds)
        print('mse: {:.3f}  mae: {:.3f}  r2: {:.3f}'.format(mse, mae, r2)) 

        my_dict = {
            'mse': "{:.3f}".format(mse),
            'mae': "{:.3f}".format(mae),
            'r2': "{:.3f}".format(r2)
        }
        with open(os.path.join(dict_path, 'records.json'), 'w', encoding='utf-8') as f:
            json.dump(my_dict, f)
        torch.cuda.empty_cache()







    def predict(self, setting, predict_data, predict_loader):
        """Load the saved checkpoint and score the model on ``predict_loader``.

        Args:
            setting: Run identifier; the checkpoint is read from
                ``os.path.join(args.checkpoints, setting, 'checkpoint.pth')``.
            predict_data: Dataset behind ``predict_loader``. Not used here;
                kept for interface compatibility.
            predict_loader: Iterable yielding ``(batch_x, batch_y)`` tensor pairs.

        Returns:
            ``(preds, trues, mse, mae, r2)``: ``preds`` and ``trues`` are 2-D
            arrays with every leading axis collapsed (last axis preserved), and
            the three metrics are computed on those arrays without any inverse
            transform.

        Raises:
            ValueError: If ``predict_loader`` yields no batches.
        """
        path = os.path.join(self.args.checkpoints, setting)

        print('loading model')
        # map_location lets a checkpoint saved on another device
        # (e.g. GPU) be loaded on whatever device this run uses.
        state_dict = torch.load(os.path.join(path, 'checkpoint.pth'), map_location=self.device)
        self.model.load_state_dict(state_dict)

        self.model.eval()

        preds = []
        trues = []

        with torch.no_grad():
            for batch_x, batch_y in predict_loader:
                batch_x = batch_x.float().to(self.device, non_blocking=True)
                # Only the last pred_len steps are scored; targets stay on the CPU.
                batch_y = batch_y[:, -self.args.pred_len:, :].float()
                if self.args.use_amp:
                    with torch.cuda.amp.autocast():
                        outputs = self.model(batch_x)
                else:
                    outputs = self.model(batch_x)
                preds.append(outputs.detach().cpu().numpy())
                trues.append(batch_y.detach().numpy())

        if not preds:
            raise ValueError('predict_loader produced no batches; nothing to predict')

        preds = np.concatenate(preds, axis=0)
        trues = np.concatenate(trues, axis=0)
        preds = preds.reshape(-1, preds.shape[-1])
        trues = trues.reshape(-1, trues.shape[-1])
        mse = mean_squared_error(trues, preds)
        mae = mean_absolute_error(trues, preds)
        r2 = r2_score(trues, preds)
        print('Prediction Results:')
        print('MSE: {:.3f}, MAE: {:.3f}, R^2: {:.3f}'.format(mse, mae, r2))      
        return preds, trues, mse, mae, r2

